
MoE prefill bf16 perf improvement for qwen-3.5-35B-A3B#18829

Draft
digantdesai wants to merge 5 commits into main from digantdesai/qwen35_moe

Conversation

@digantdesai
Contributor

@digantdesai digantdesai commented Apr 11, 2026

                     Baseline    Batched      Speedup
Prefill (1341 tok)   588 tok/s   1807 tok/s   3.07x
Decode (128 tok)     90 tok/s    86 tok/s     ~1.0x

@pytorch-bot

pytorch-bot bot commented Apr 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18829

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 5 New Failures, 2 Pending, 4 Unrelated Failures

As of commit 5055971 with merge base 266ff2d:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 11, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@digantdesai digantdesai force-pushed the digantdesai/qwen35_moe branch from a0d199a to 63548f5 on April 13, 2026 at 15:15
@digantdesai digantdesai changed the title Add CUDA sort shim for AOTI export (thrust-based sort_stable fallback) [AOTI-CUDA] MoE prefill bf16 perf improvement for qwen-3.5-35B-A3B Apr 13, 2026
@digantdesai digantdesai changed the title [AOTI-CUDA] MoE prefill bf16 perf improvement for qwen-3.5-35B-A3B MoE prefill bf16 perf improvement for qwen-3.5-35B-A3B Apr 13, 2026
Inductor emits aten::sort.stable for ops like argsort, but lacks a
native c-shim for it. This adds a thrust-based implementation
(aoti_torch_cuda_sort_stable) that handles int64, int32, and float32
dtypes on contiguous innermost-dim tensors. Registered as a supported
fallback kernel in CudaBackend so AOTI-compiled models can use sort.

This PR was authored with the assistance of Claude.
Sweeps prompt lengths [1..4095] with Qwen3.5-35B-A3B shapes (256 experts,
top-8, INT4 W4A16). Validates correctness against loop-based eager reference
at small M, benchmarks vectorized eager, torch.compile, and Triton fused_moe.
Handles OOM gracefully at large M where eager/compile dequantize all experts.

This PR was authored with the assistance of Claude.
When the Triton tile size fits within a single quantization group, load
one scale per N-element instead of per (K, N) element. Reduces scale
memory traffic in both GEMM1 and GEMM2 vec-mat kernels.

This PR was authored with the assistance of Claude.
Adds a batched (M>1) Triton fused MoE kernel using tensor-core mma
instructions for prefill workloads. Includes moe_align_block_size for
token-expert sorting and scale broadcast optimization in the batched
GEMM inner loops.

Weight layout: [E, N, K//2] (packed INT4).

This PR was authored with the assistance of Claude.
Add use_batched_moe flag on FusedMoEExperts, toggled by _set_batched_moe
in export.py before each method's torch.export call. Decode (T=1) uses
the vec-mat fused_moe kernel; prefill (T>=2) uses fused_moe_batched_gemm.

This PR was authored with the assistance of Claude.
@digantdesai digantdesai force-pushed the digantdesai/qwen35_moe branch from 63548f5 to 5055971 on April 13, 2026 at 19:46
@digantdesai
Contributor Author

@claude review this PR.

@claude

claude bot commented Apr 13, 2026

Claude finished @digantdesai's task in 2m 33s.


PR Review: MoE prefill bf16 perf improvement for qwen-3.5-35B-A3B

  • Read and analyze all changed files
  • Review kernel implementation (fused_moe.py)
  • Review CUDA sort shim (sort.cu, sort.h)
  • Review benchmarks and tests
  • Review export/model changes
  • Review build system changes
  • Post comprehensive review feedback

This PR adds a batched tensor-core MoE kernel for prefill (fused_moe_batched_gemm) alongside the existing vec-mat decode kernel, achieving a 3x speedup on prefill. The changes are well-structured across 5 commits with clear separation of concerns. Here's a detailed review:

Overall Architecture

The approach is sound: use the existing fused_moe vec-mat kernel for decode (M=1, bandwidth-bound) and a new fused_moe_batched_gemm kernel with token sorting + tl.dot for prefill (M>>1, compute-bound). The export pipeline toggles between them via use_batched_moe flag.
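The dispatch described above can be sketched as follows (hypothetical wrapper; the kernel bodies are stand-ins, only the names follow the PR):

```python
def fused_moe(x):
    # Stand-in for the existing vec-mat decode kernel.
    return "fused_moe"

def fused_moe_batched_gemm(x):
    # Stand-in for the new batched tensor-core prefill kernel.
    return "fused_moe_batched_gemm"

def run_moe(x, num_tokens: int):
    # Decode (num_tokens == 1) is bandwidth-bound -> vec-mat kernel;
    # prefill (num_tokens > 1) is compute-bound -> batched GEMM kernel.
    kernel = fused_moe if num_tokens == 1 else fused_moe_batched_gemm
    return kernel(x)
```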


fused_moe.py — Batched Kernel + Token Sorting

Positive:

  • The moe_align_block_size implementation avoids data-dependent shapes, making it compatible with torch.export. Using argsort on constructed unique keys (flat_ids * num_pairs + arange) instead of stable=True sort is a pragmatic workaround.
  • The scale loading optimization (BLOCK_SIZE_K <= group_size branch) avoids unnecessary per-element scale indexing when the entire tile shares one scale group. Applied consistently to all four kernels.
  • _BATCHED_BLOCK_M = 16 is well-justified as the minimum for tl.dot with minimal padding waste.
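The unique-key workaround can be illustrated in isolation (a minimal sketch, not the kernel's actual code):

```python
import torch

def stable_argsort_via_unique_keys(flat_ids: torch.Tensor) -> torch.Tensor:
    # Emulate argsort(..., stable=True) without lowering to
    # aten::sort.stable: fold each element's original position into the
    # key so ties are impossible and equal ids keep their input order.
    n = flat_ids.numel()
    keys = flat_ids * n + torch.arange(n, device=flat_ids.device)
    return torch.argsort(keys)

ids = torch.tensor([2, 0, 2, 1, 0])
order = stable_argsort_via_unique_keys(ids)
# The two 0s and the two 2s retain their original relative order.
```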

Issues and suggestions:

  1. Potential race condition in GEMM2 scatter (fused_moe.py:830-834): In _fused_moe_silu_batched_kernel, multiple M-blocks scatter their output to original pair positions via pair_ids. If two different expert blocks happen to have the same pair_id values (shouldn't happen by construction of moe_align_block_size since each pair is assigned to exactly one expert), this would be fine. However, sentinel rows all write to out_buf[num_pairs] — multiple blocks may write to the same sentinel row concurrently. Since sentinel row data is discarded (out_buf[:num_pairs]), this is benign but worth a comment.

  2. moe_align_block_size worst-case allocation (fused_moe.py:483): max_num_tokens_padded = num_pairs + num_experts * block_size. For Qwen3.5 (256 experts, block_size=16, 1341 tokens * top_k=8 = 10728 pairs), this is 10728 + 256*16 = 14824 — reasonable. But the intermediate sorted output cache1 is allocated at max_padded rows (fused_moe.py:879-884), which could be up to 14824 rows even though the actual padded count may be smaller. Consider using num_tokens_post_padded for the grid calculation and documenting why the allocation uses the worst-case bound.

  3. fused_moe_batched convenience wrapper (fused_moe.py:970-994): This wrapper is identical to calling fused_moe_batched_gemm directly. If it only exists for benchmarking, it could be removed and the benchmark could call fused_moe_batched_gemm directly.

  4. Grid size for batched GEMM2 (fused_moe.py:922): The grid uses num_expert_blocks which is max_num_tokens_padded // BLOCK_M. For blocks beyond num_tokens_post_padded, the kernel processes sentinel rows that produce zero output (weights are zeroed). This works but launches potentially wasted CTAs. For the Qwen3.5 dimensions this is minor (extra ~256 blocks at most), but for models with more experts relative to tokens, it could matter.

  5. out_buf is zero-initialized (fused_moe.py:914-919): The scatter-based GEMM2 writes to arbitrary positions, so torch.zeros is needed (unused positions stay zero for the sum). This is correct but the cache1 for GEMM1 uses torch.empty which is also fine since every position in max_padded gets written in sorted order. Good.
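The worst-case bound in item 2 works out as a one-liner (sketch; the numbers are the ones quoted in the review):

```python
def max_num_tokens_padded(num_pairs: int, num_experts: int,
                          block_size: int) -> int:
    # Each expert's run of pairs is rounded up to a multiple of
    # block_size, adding at most block_size - 1 padding rows per expert;
    # num_experts * block_size is a convenient safe upper bound.
    return num_pairs + num_experts * block_size

# Qwen3.5-35B-A3B prefill shapes from the review:
pairs = 1341 * 8                                  # (token, expert) pairs
bound = max_num_tokens_padded(pairs, 256, 16)     # 10728 + 4096 = 14824
```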


sort.cu — CUDA Sort Shim

Positive:

  • Handles non-innermost dimension sorting via transpose → sort → transpose-back, which is the standard approach.
  • Supports multiple dtypes including bf16 and fp16.
  • Error handling with ET_CHECK_OR_RETURN_ERROR is thorough.
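The transpose → sort → transpose-back strategy can be mimicked in plain PyTorch (illustration only; the shim performs the equivalent with thrust on device buffers):

```python
import torch

def sort_along_dim(x: torch.Tensor, dim: int):
    # Move the requested dim innermost, sort each contiguous slice,
    # then move the results back to the original layout.
    xt = x.transpose(dim, -1).contiguous()
    values, indices = torch.sort(xt, dim=-1)
    return values.transpose(dim, -1), indices.transpose(dim, -1)

x = torch.tensor([[3., 1.], [0., 2.]])
vals, idx = sort_along_dim(x, 0)   # sort down each column
```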

Issues and suggestions:

  1. Sequential slice sorting (sort.cu:345-403): The for (int64_t s = 0; s < num_slices; ++s) loop launches separate thrust sorts per slice. For tensors with many small slices (e.g., [1000, 8] sorted on dim=-1), this means 1000 sequential kernel launches. CUB's segmented sort (cub::DeviceSegmentedSort) would be more efficient for this case — but this is a known limitation of thrust-based implementations and may not be needed for the current use case.

  2. cudaMallocAsync/cudaFreeAsync in transpose path (sort.cu:299-306, 431-432): These allocate temporary buffers per call. For repeated sort operations (e.g., in a loop), a workspace allocator would be more efficient. For the MoE use case where argsort is called once per forward pass, this is fine.

  3. Missing elem_size == 1 case in launch_permute (sort.cu:80-108): The switch statement handles 8, 4, and 2 byte elements but silently does nothing for 1-byte types. Since the supported dtypes are all ≥2 bytes this is fine, but a default case with an error would be defensive.


Tests

Positive:

  • test_batched_correctness covers multiple M values (8, 16, 32, 64, 128) with varying expert counts and group sizes.
  • test_batched_matches_fused validates that the batched kernel matches the existing decode kernel at full Qwen-scale dimensions (E=256, K=2048).
  • TestMoeAlignBlockSize has good coverage of edge cases (single token, all-same-expert, sentinel values, determinism).
  • Sort shim tests verify both export success and fallback kernel registration.

Suggestions:

  1. Missing batched kernel test for M=1 (test_fused_moe.py:339): The batched kernel should gracefully handle M=1 (single token during prefill), but there's no test for this. It would be padded to BLOCK_M=16 with 15 sentinels — worth testing.

  2. Tolerance of 0.5 abs error in test_batched_matches_fused (test_fused_moe.py:425): This seems high. The two kernels operate on the same quantized weights, so differences come from fp accumulation order. Consider also checking relative error like the other tests do. The batched kernel uses tl.dot (tensor cores) while decode uses element-wise multiply+reduce, which can give different rounding — 0.5 abs may be appropriate but worth a comment explaining why.
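Both suggestions can be sketched together in plain PyTorch (hypothetical and independent of the real kernel API): pad a single token to BLOCK_M=16 with zero-weight sentinel rows, then compare against the unpadded reference with a combined absolute/relative check.

```python
import torch

torch.manual_seed(0)
BLOCK_M, K, N = 16, 8, 4
w = torch.randn(K, N)
x = torch.randn(1, K)                 # a single real token (M = 1)
tile = torch.zeros(BLOCK_M, K)        # tile padded with sentinel rows
tile[0] = x[0]
routing = torch.zeros(BLOCK_M, 1)     # sentinels carry zero routing weight
routing[0] = 1.0
out = (tile @ w) * routing            # stand-in for the mma tile compute
ref = x @ w                           # reference: unpadded computation

# Combined abs/rel check (tolerances are placeholders; differing
# accumulation order between tensor cores and multiply+reduce is what
# justifies the absolute slack):
err = (out[0] - ref[0]).abs()
assert torch.all(err <= 0.5 + 2e-2 * ref[0].abs())
assert torch.all(out[1:] == 0)        # sentinel rows contribute nothing
```

For real tests, torch.testing.assert_close(actual, expected, atol=..., rtol=...) implements the same combined check.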


Export / Model Changes

Positive:

  • The _set_batched_moe toggle pattern is clean — trace decode with use_batched_moe=False, trace prefill with use_batched_moe=True.
  • Using example_seq_len = config.max_seq_len - 1 for the prefill example is correct to ensure AOTI compiles kernels for the full dynamic range.

Suggestions:

  1. Prefill example allocation (export.py:440-441): torch.zeros((1, example_seq_len), dtype=torch.long) allocates a tensor of size max_seq_len-1 (default 4095) just as a tracing example. This is fine since it's CPU-only during export, but for large max_seq_len values it could be notable.

Build System

  1. CMakeLists.txt (backends/cuda/CMakeLists.txt:112-114): sort.cu is correctly gated behind CMAKE_CUDA_COMPILER. The TARGETS file also adds both sort.cu and sort.h appropriately.

  2. cuda_backend.py:148: Adding "at::_ops::sort_stable::call" to get_supported_fallback_kernels is the correct registration for the sort shim.


Summary

Strong PR with a well-motivated performance optimization. The batched tensor-core kernel + token sorting approach is standard for MoE prefill and the implementation is solid. The scale loading optimization is a nice touch. The CUDA sort shim fills a real gap for AOTI export.

Key items to consider:

  • Consolidate or remove the fused_moe_batched wrapper
  • Add a default/error case in launch_permute's switch
  • Add a test for the batched kernel with M=1
  • Add a comment explaining the 0.5 absolute tolerance
